{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute main.py in the notebook namespace; the names used by the cells\n",
    "# below (paths, S3, model classes, ...) are expected to come from it.\n",
    "%run main.py\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "# Create the working directories in pure Python instead of `!mkdir -p`:\n",
    "# no shell is involved, so paths with spaces or shell metacharacters are\n",
    "# handled correctly, and a failure raises instead of being printed and ignored.\n",
    "os.makedirs(DATA_DIR, exist_ok=True)\n",
    "os.makedirs(BERT_DIR, exist_ok=True)\n",
    "os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "\n",
    "# S3 client used by the download cells below.\n",
    "s3 = S3()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check each corpus on its own: with a single check on NEWS, a run that\n",
    "# was interrupted after NEWS arrived would never fetch FICTION.\n",
    "if not exists(NEWS):\n",
    "    s3.download(S3_NEWS, NEWS)\n",
    "if not exists(FICTION):\n",
    "    s3.download(S3_FICTION, FICTION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check each BERT file on its own: with a single check on BERT_VOCAB, a\n",
    "# partial download (vocab present, weights missing) would never be repaired.\n",
    "if not exists(BERT_VOCAB):\n",
    "    s3.download(S3_BERT_VOCAB, BERT_VOCAB)\n",
    "if not exists(BERT_EMB):\n",
    "    s3.download(S3_BERT_EMB, BERT_EMB)\n",
    "if not exists(BERT_ENCODER):\n",
    "    s3.download(S3_BERT_ENCODER, BERT_ENCODER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the BERT vocab from the local BERT_VOCAB file (downloaded from S3 in\n",
    "# the previous cell); it is passed to the train encoder as words_vocab.\n",
    "words_vocab = BERTVocab.load(BERT_VOCAB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read both corpora fully into memory, keyed by split name:\n",
    "# NEWS is stored under TEST and FICTION under TRAIN.\n",
    "markups = {}\n",
    "for path, name in [(NEWS, TEST), (FICTION, TRAIN)]:\n",
    "    items = parse_jl(load_gz_lines(path))\n",
    "    records = [\n",
    "        SyntaxMarkup.from_json(item)\n",
    "        for item in log_progress(items, desc=path)\n",
    "    ]\n",
    "    markups[name] = records\n",
    "\n",
    "# Collect every relation label seen in either split. PAD is placed first,\n",
    "# ahead of the sorted labels (presumably so it gets index 0 -- confirm).\n",
    "rels = {\n",
    "    token.rel\n",
    "    for name in [TEST, TRAIN]\n",
    "    for markup in markups[name]\n",
    "    for token in markup.tokens\n",
    "}\n",
    "rels = [PAD] + sorted(rels)\n",
    "rels_vocab = Vocab(rels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch and the `seed` helper that comes from main.py (presumably\n",
    "# random.seed -- confirm) with the same SEED before the batches are built.\n",
    "torch.manual_seed(SEED)\n",
    "seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train encoder built from the word and relation vocabs.\n",
    "encode = BERTSyntaxTrainEncoder(\n",
    "    words_vocab,\n",
    "    rels_vocab,\n",
    "    seq_len=128,\n",
    "    batch_size=16,\n",
    "    sort_size=10000,\n",
    ")\n",
    "\n",
    "# Encode each split once and keep every batch on DEVICE, so the training\n",
    "# loop iterates over ready-made lists on every epoch.\n",
    "batches = {}\n",
    "for name in [TEST, TRAIN]:\n",
    "    records = log_progress(encode(markups[name]), desc=name)\n",
    "    batches[name] = [batch.to(DEVICE) for batch in records]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model parts from the RuBERT config. Both the head and the rel\n",
    "# module take emb_dim inputs and use a hidden size of half of emb_dim; the\n",
    "# rel module is additionally sized by the number of entries in rels_vocab.\n",
    "config = RuBERTConfig()\n",
    "emb = BERTEmbedding.from_config(config)\n",
    "encoder = BERTEncoder.from_config(config)\n",
    "head = BERTSyntaxHead(\n",
    "    input_dim=config.emb_dim,\n",
    "    hidden_dim=config.emb_dim // 2,\n",
    ")\n",
    "rel = BERTSyntaxRel(\n",
    "    input_dim=config.emb_dim,\n",
    "    hidden_dim=config.emb_dim // 2,\n",
    "    rel_dim=len(rels_vocab)\n",
    ")\n",
    "model = BERTSyntax(emb, encoder, head, rel)\n",
    "\n",
    "# Freeze the embedding: its parameters get no gradients and are also left\n",
    "# out of the optimizer parameter groups defined further down.\n",
    "for param in emb.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "# Load the embedding and encoder from the BERT files downloaded earlier\n",
    "# (presumably pretrained weights -- confirm), then move the model to DEVICE.\n",
    "model.emb.load(BERT_EMB)\n",
    "model.encoder.load(BERT_ENCODER)\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "# Loss function handed to process_batch in the training loop.\n",
    "criterion = masked_flatten_cross_entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A MultiBoard wrapping a TensorBoard board and a LogBoard.\n",
    "board = MultiBoard([TensorBoard(BOARD_NAME, RUNS_DIR), LogBoard()])\n",
    "\n",
    "# One board section per split; the meters in the training loop write to them.\n",
    "boards = {\n",
    "    split: board.section(section)\n",
    "    for split, section in [(TRAIN, TRAIN_BOARD), (TEST, TEST_BOARD)]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two parameter groups with separate learning rates: the encoder uses\n",
    "# BERT_LR, the head and rel modules share LR. The embedding parameters are\n",
    "# not passed to the optimizer at all.\n",
    "optimizer = optim.Adam([\n",
    "    {'params': encoder.parameters(), 'lr': BERT_LR},\n",
    "    {'params': chain(head.parameters(), rel.parameters()), 'lr': LR},\n",
    "])\n",
    "\n",
    "# Each scheduler.step() multiplies every group's learning rate by LR_GAMMA.\n",
    "scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# One score meter per split; each is written to its board section and\n",
    "# reset at the end of its pass, once per epoch.\n",
    "meters = {\n",
    "    TRAIN: SyntaxScoreMeter(),\n",
    "    TEST: SyntaxScoreMeter()\n",
    "}\n",
    "\n",
    "for epoch in log_progress(range(EPOCHS)):\n",
    "    # Training pass: one optimizer step per batch.\n",
    "    model.train()\n",
    "    # desc=TRAIN labels this progress bar the same way the TEST bar is labelled.\n",
    "    for batch in log_progress(batches[TRAIN], leave=False, desc=TRAIN):\n",
    "        optimizer.zero_grad()\n",
    "        # process_batch returns the batch object that carries .loss.\n",
    "        batch = process_batch(model, criterion, batch)\n",
    "        batch.loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        score = score_syntax_batch(batch)\n",
    "        meters[TRAIN].add(score)\n",
    "\n",
    "    meters[TRAIN].write(boards[TRAIN])\n",
    "    meters[TRAIN].reset()\n",
    "\n",
    "    # Evaluation pass: gradients disabled, no parameter updates.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch in log_progress(batches[TEST], leave=False, desc=TEST):\n",
    "            batch = process_batch(model, criterion, batch)\n",
    "            score = score_syntax_batch(batch)\n",
    "            meters[TEST].add(score)\n",
    "        meters[TEST].write(boards[TEST])\n",
    "        meters[TEST].reset()\n",
    "\n",
    "    # Once per epoch: decay the learning rates and call board.step().\n",
    "    scheduler.step()\n",
    "    board.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log output kept from an earlier run of the training cell, for reference.\n",
    "# Columns appear to be: [timestamp], step (one per epoch), value, metric.\n",
    "\n",
    "# [2020-04-23 08:34:18]    0 0.4612 01_train/01_loss\n",
    "# [2020-04-23 08:34:18]    0 0.9047 01_train/02_uas\n",
    "# [2020-04-23 08:34:18]    0 0.8783 01_train/03_las\n",
    "# [2020-04-23 08:34:19]    0 0.3214 02_test/01_loss\n",
    "# [2020-04-23 08:34:19]    0 0.9618 02_test/02_uas\n",
    "# [2020-04-23 08:34:19]    0 0.9300 02_test/03_las\n",
    "\n",
    "# [2020-04-23 08:39:07]    1 0.2017 01_train/01_loss\n",
    "# [2020-04-23 08:39:07]    1 0.9511 01_train/02_uas\n",
    "# [2020-04-23 08:39:07]    1 0.9348 01_train/03_las\n",
    "# [2020-04-23 08:39:07]    1 0.3035 02_test/01_loss\n",
    "# [2020-04-23 08:39:07]    1 0.9635 02_test/02_uas\n",
    "# [2020-04-23 08:39:07]    1 0.9311 02_test/03_las"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Left commented out so that Run All does not overwrite stored artifacts.\n",
    "# Uncomment to dump the encoder, head, rel module and rels vocab to local\n",
    "# files and then upload those files to S3.\n",
    "\n",
    "# model.encoder.dump(MODEL_ENCODER)\n",
    "# model.head.dump(MODEL_HEAD)\n",
    "# model.rel.dump(MODEL_REL)\n",
    "# rels_vocab.dump(RELS_VOCAB)\n",
    "\n",
    "# s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)\n",
    "# s3.upload(MODEL_HEAD, S3_MODEL_HEAD)\n",
    "# s3.upload(MODEL_REL, S3_MODEL_REL)\n",
    "# s3.upload(RELS_VOCAB, S3_RELS_VOCAB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
